import argparse
import os
import torch

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Per-modality feature dimensions. All three start at 0 and are rebound by
# set_dataset_config() once a dataset has been chosen.
# NOTE: code that does `from <this module> import TEXT_DIM` copies the value
# at import time and will not observe that later rebinding.
TEXT_DIM = ACOUSTIC_DIM = VISUAL_DIM = 0

def set_dataset_config(dataset_name):
    """Set the module-level feature dimensions for ``dataset_name``.

    Rebinds the module globals ``TEXT_DIM``, ``ACOUSTIC_DIM`` and
    ``VISUAL_DIM`` to the values registered for the dataset.

    Args:
        dataset_name: One of the supported dataset keys, ``"mosi"`` or
            ``"mosei"``.

    Returns:
        ``(text_dim, acoustic_dim, visual_dim)`` for the dataset. The globals
        are rebound, so any module that already did
        ``from <this module> import TEXT_DIM`` still holds the old value;
        such callers should use this return value (or module attribute
        access) instead.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    global TEXT_DIM, ACOUSTIC_DIM, VISUAL_DIM

    dataset_configs = {
        "mosi": {"ACOUSTIC_DIM": 74, "VISUAL_DIM": 47, "TEXT_DIM": 768},
        "mosei": {"ACOUSTIC_DIM": 74, "VISUAL_DIM": 35, "TEXT_DIM": 768},
    }

    # Look the name up directly (EAFP) rather than relying on the truthiness
    # of dict.get(); an unknown key is reported as a ValueError to callers.
    try:
        config = dataset_configs[dataset_name]
    except KeyError:
        raise ValueError(f"Invalid dataset name: {dataset_name}") from None

    ACOUSTIC_DIM = config["ACOUSTIC_DIM"]
    VISUAL_DIM = config["VISUAL_DIM"]
    TEXT_DIM = config["TEXT_DIM"]
    return TEXT_DIM, ACOUSTIC_DIM, VISUAL_DIM


# Command-line configuration.
# NOTE: parse_args() below runs at import time, so importing this module
# parses the host process's sys.argv; unrecognised arguments make argparse
# print an error and exit.
parser = argparse.ArgumentParser()
# Runtime.
# NOTE(review): the 'cuda:2' default selects CUDA device index 2, i.e. it
# assumes at least three visible GPUs; pass --device on smaller machines.
parser.add_argument('--device', type=str, default='cuda:2')
parser.add_argument("--seed", type=int, default=128)
# Backbone model location (default points at a local deberta-v3-base copy)
# and dataset selection.
parser.add_argument("--model", type=str, default="../2.Bert/deberta-v3-base")
parser.add_argument("--dataset", type=str, choices=["mosi", "mosei"], default="mosi")
parser.add_argument("--max_seq_length", type=int, default=50)
# Batch sizes per split.
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--dev_batch_size", type=int, default=128)
parser.add_argument("--test_batch_size", type=int, default=128)
# Optimisation schedule.
parser.add_argument("--n_epochs", type=int, default=30)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--gradient_accumulation_step", type=int, default=1)
parser.add_argument("--warmup_proportion", type=float, default=0.1)
# Regularisation.
parser.add_argument("--dropout_prob", help='drop probability for dropout', default=0.3, type=float)
# Layer sizes.
parser.add_argument('--hidden_dim', default=512, help='hidden dimension for multimodal feature projection', type=int)
parser.add_argument('--dim', default=512, help='dimension for mib', type=int)
parser.add_argument('--output_dim', default=256, help='out_dim for mib', type=int)
# Weighting coefficients (presumably for loss terms -- confirm against the
# training code). Defaults are written as floats because argparse applies
# `type` only to string defaults: `default=1` would leave args.p_alpha an int
# unless the flag is given explicitly.
parser.add_argument('--p_alpha', default=1.0, help='coefficient -- alpha', type=float)
parser.add_argument('--p_beta', default=1e-3, help='coefficient -- beta', type=float)
parser.add_argument('--p_lambda1', default=0.5, help='coefficient -- lambda1', type=float)
parser.add_argument('--p_lambda2', default=0.5, help='coefficient -- lambda2', type=float)

args = parser.parse_args()

DEVICE = torch.device(args.device)